{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "torch.Size([4, 23, 8, 8])\n"
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F \n",
    "import d2lzh as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def conv_block(in_channels,out_channels):\n",
    "    \"\"\"Build one BatchNorm -> ReLU -> 3x3 convolution unit.\n",
    "\n",
    "    The convolution maps in_channels to out_channels using kernel_size=3\n",
    "    and padding=1 (default stride), so height and width are unchanged.\n",
    "\n",
    "    Returns:\n",
    "        nn.Sequential holding the three layers in that order.\n",
    "    \"\"\"\n",
    "    layers = [\n",
    "        nn.BatchNorm2d(in_channels),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1),\n",
    "    ]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "class DenseBlock(nn.Module):\n",
    "    \"\"\"Stack of conv blocks whose outputs are concatenated along dim=1.\n",
    "\n",
    "    Each of the num_convs blocks receives the original input concatenated\n",
    "    with the outputs of all earlier blocks and contributes out_channels new\n",
    "    channels. The final channel count is stored in self.out_channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,num_convs,in_channels,out_channels):\n",
    "        super().__init__()\n",
    "        # Block k sees the input plus the k*out_channels channels added so far.\n",
    "        self.net = nn.ModuleList([\n",
    "            conv_block(in_channels + k * out_channels,out_channels)\n",
    "            for k in range(num_convs)\n",
    "        ])\n",
    "        self.out_channels = in_channels + num_convs * out_channels\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Return x with every block's output appended on the channel axis.\"\"\"\n",
    "        features = x\n",
    "        for layer in self.net:\n",
    "            # Concatenate the block's input and output along the channel dimension.\n",
    "            features = torch.cat((features,layer(features)),dim=1)\n",
    "        return features\n",
    "\n",
    "# Sanity check: 2 conv blocks on a 3-channel input, each adding 10 channels,\n",
    "# so the output has 3 + 2*10 = 23 channels. `y` is reused by the next cell.\n",
    "blk = DenseBlock(2,3,10)\n",
    "x = torch.randn(4,3,8,8)\n",
    "y = blk(x)\n",
    "print(y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "torch.Size([4, 10, 4, 4])"
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "source": [
    "def transition_block(in_channels,out_channels):\n",
    "    \"\"\"Build a transition layer: BatchNorm -> ReLU -> 1x1 conv -> 2x2 average pool.\n",
    "\n",
    "    The 1x1 convolution maps in_channels to out_channels; the average\n",
    "    pooling uses kernel_size=2 and stride=2, halving height and width.\n",
    "\n",
    "    Returns:\n",
    "        nn.Sequential holding the four layers in that order.\n",
    "    \"\"\"\n",
    "    stages = [\n",
    "        nn.BatchNorm2d(in_channels),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(in_channels,out_channels,kernel_size=1),\n",
    "        nn.AvgPool2d(kernel_size=2,stride=2),\n",
    "    ]\n",
    "    return nn.Sequential(*stages)\n",
    "\n",
    "# Apply a transition layer to `y`, the 23-channel dense-block output from the\n",
    "# previous cell: channels go 23 -> 10 and height/width are halved.\n",
    "blk = transition_block(23,10)\n",
    "blk(y).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "0  output shape:\t torch.Size([1, 64, 48, 48])\n1  output shape:\t torch.Size([1, 64, 48, 48])\n2  output shape:\t torch.Size([1, 64, 48, 48])\n3  output shape:\t torch.Size([1, 64, 24, 24])\nDenseBlock_0  output shape:\t torch.Size([1, 192, 24, 24])\ntransition_block_0  output shape:\t torch.Size([1, 96, 12, 12])\nDenseBlock_1  output shape:\t torch.Size([1, 224, 12, 12])\ntransition_block_1  output shape:\t torch.Size([1, 112, 6, 6])\nDenseBlock_2  output shape:\t torch.Size([1, 240, 6, 6])\ntransition_block_2  output shape:\t torch.Size([1, 120, 3, 3])\nDenseBlock_3  output shape:\t torch.Size([1, 248, 3, 3])\nBN  output shape:\t torch.Size([1, 248, 3, 3])\nrelu  output shape:\t torch.Size([1, 248, 3, 3])\nglobal_avg_pool  output shape:\t torch.Size([1, 248, 1, 1])\nfc  output shape:\t torch.Size([1, 10])\n"
    }
   ],
   "source": [
    "# Stem: 7x7 stride-2 convolution, batch norm, ReLU, then 3x3 stride-2 max pooling.\n",
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),\n",
    "    nn.BatchNorm2d(64),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(kernel_size=3,stride=2,padding=1)\n",
    ")\n",
    "\n",
    "# num_channels tracks the channel count of the feature map built so far;\n",
    "# growth_rate is the number of channels each conv block adds.\n",
    "num_channels,growth_rate=64,32\n",
    "num_convs_in_dense_blocks = [4,4,4,4]\n",
    "\n",
    "for i,num_convs in enumerate(num_convs_in_dense_blocks):\n",
    "    db = DenseBlock(num_convs,num_channels,growth_rate)\n",
    "    net.add_module(\"DenseBlock_%d\" % i,db)\n",
    "    num_channels = db.out_channels\n",
    "    # Between dense blocks (not after the last one), add a transition layer\n",
    "    # that halves the number of channels.\n",
    "    if i != len(num_convs_in_dense_blocks)-1:\n",
    "        net.add_module(\"transition_block_%d\" % i,transition_block(num_channels,num_channels//2))\n",
    "        num_channels = num_channels // 2\n",
    "\n",
    "# Head: batch norm, ReLU, global average pooling, then a 10-way linear layer.\n",
    "net.add_module(\"BN\",nn.BatchNorm2d(num_channels))\n",
    "net.add_module(\"relu\",nn.ReLU())\n",
    "net.add_module(\"global_avg_pool\",d2l.GlobalAvgPool2d())\n",
    "net.add_module(\"fc\",nn.Sequential(d2l.FlattenLayer(),nn.Linear(num_channels,10)))\n",
    "\n",
    "\n",
    "# Shape probe: push one random 1x1x96x96 tensor through the net layer by layer.\n",
    "# Run it in eval mode under no_grad so the random input neither updates the\n",
    "# BatchNorm running statistics nor builds an autograd graph; training mode is\n",
    "# restored afterwards so the following training cell starts from a clean net.\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    x = torch.rand(1,1,96,96)\n",
    "    for name,layer in net.named_children():\n",
    "        x = layer(x)\n",
    "        print(name,' output shape:\\t',x.shape)\n",
    "net.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "training on  cuda\nepoch 1, loss 0.4464, train acc 0.841,test acc 0.866,time 50.3\nepoch 2, loss 0.1359, train acc 0.901,test acc 0.886,time 48.9\nepoch 3, loss 0.0775, train acc 0.914,test acc 0.901,time 49.1\nepoch 4, loss 0.0520, train acc 0.924,test acc 0.919,time 49.1\nepoch 5, loss 0.0374, train acc 0.930,test acc 0.844,time 49.1\n"
    }
   ],
   "source": [
    "batch_size = 64 # if an \"out of memory\" error occurs, reduce batch_size or resize \n",
    "# Load Fashion-MNIST via the d2l helper; resize=96 presumably rescales images\n",
    "# to 96x96 to match the shape probe above -- TODO confirm against d2l.\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
    "lr, num_epochs = 0.001, 5 \n",
    "# Adam over all parameters of net.\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "# NOTE(review): train_ch5 is assumed to move net to `device` and run the\n",
    "# train/eval loop -- verify in d2lzh.\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}